Simple Non-Linear Setting

Non-Linear Model with Binary Treatment and Continuous Covariates

Authors and Affiliations

Nolan Cole: Department of Biostatistics, University of Washington

Lars van der Laan: Department of Statistics, University of Washington

Marco Carone: Department of Biostatistics, University of Washington; Department of Statistics, University of Washington

Published

February 20, 2026

Simulation Details


Data Generating Process

\begin{align*} W_1 &\sim \mathrm{Unif}(0,1), \\ W_2 &\sim \mathrm{Unif}(-1,1), \\ W_3 &\sim \mathrm{Unif}(-5,5), \\ W &= (W_1, W_2, W_3)^\top, \\ V &= W_1, \\ \varepsilon &\sim \mathrm{N}(0,\sigma^2), \\ \pi_0(W) &= P(A=1 \mid W) = \mathrm{logit}^{-1}\!\big(\gamma_0 + \gamma_1 W_1 + \gamma_2 W_2 + \gamma_3 W_3\big), \\ \tau_0(W) &= \tfrac{1}{2}\sin(2\pi W_1) + \tfrac{1}{2}\cos(\pi W_2), \\ \mu_{00}(W) &= -2 + 0.5\,W_1 - 0.25\,W_2 + 0.1\,W_3, \\ \mu_{01}(W) &= \mu_{00}(W) + \tau_0(W), \\ Y &= \mu_{0A}(W) + \varepsilon, \\ \tau(W) &= \mathbb{E}[Y \mid A=1, W] - \mathbb{E}[Y \mid A=0, W] = \tau_0(W), \\ \bar{\tau}_0(v) &= \mathbb{E}[\tau_0(W) \mid V=v] = \tfrac{1}{2}\sin(2\pi v). \end{align*}
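
The closed form in the last line follows because \cos(\pi W_2) has mean zero under \mathrm{Unif}(-1,1). A short worked derivation, assuming W_1 and W_2 are drawn independently (as they are in the simulation code in the Appendix):

\begin{align*}
\bar{\tau}_0(v) &= \mathbb{E}\big[\tfrac{1}{2}\sin(2\pi W_1) + \tfrac{1}{2}\cos(\pi W_2) \mid W_1 = v\big] \\
&= \tfrac{1}{2}\sin(2\pi v) + \tfrac{1}{2}\,\mathbb{E}[\cos(\pi W_2)] \\
&= \tfrac{1}{2}\sin(2\pi v) + \tfrac{1}{2}\cdot\tfrac{1}{2}\int_{-1}^{1}\cos(\pi w)\,dw \\
&= \tfrac{1}{2}\sin(2\pi v) + \tfrac{1}{4}\left[\frac{\sin(\pi w)}{\pi}\right]_{-1}^{1} \\
&= \tfrac{1}{2}\sin(2\pi v).
\end{align*}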


Simulation Parameters

| Parameter | Value |
|---|---|
| \gamma_0 | 0 |
| \gamma_1 | -0.75 |
| \gamma_2 | 0.5 |
| \gamma_3 | -0.25 |
| \sigma | 0.2 |
| n_{\text{vals}} | \{250,\ 500,\ 1000,\ 2500,\ 10000\} |
| t_{\text{subset}} | \{-0.6,\ -0.5,\ -0.25,\ 0,\ 0.25,\ 0.5,\ 0.6\} |
| Support bounds | [-0.5,\ 0.5] |
| \text{ngrid} | 2500 |
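
As a minimal sketch of how one dataset could be drawn from this data-generating process with the parameter values above: the function name `simulate_dgp` and its arguments are illustrative and do not come from the original analysis code.

```r
# Minimal sketch of the data-generating process.
# `simulate_dgp`, `gamma`, and `sigma` are hypothetical names used for illustration.
simulate_dgp <- function(n, gamma = c(0, -0.75, 0.5, -0.25), sigma = 0.2) {
  # Covariates: independent uniforms on the stated supports
  W1 <- runif(n, 0, 1)
  W2 <- runif(n, -1, 1)
  W3 <- runif(n, -5, 5)

  # Logistic propensity score and binary treatment
  pi0 <- plogis(gamma[1] + gamma[2] * W1 + gamma[3] * W2 + gamma[4] * W3)
  A <- rbinom(n, 1, pi0)

  # CATE, baseline outcome regression, and observed outcome
  tau0 <- 0.5 * sin(2 * pi * W1) + 0.5 * cos(pi * W2)
  mu00 <- -2 + 0.5 * W1 - 0.25 * W2 + 0.1 * W3
  Y <- mu00 + A * tau0 + rnorm(n, mean = 0, sd = sigma)

  # V-specific CATE with V = W1
  bar_tau0 <- 0.5 * sin(2 * pi * W1)

  data.frame(Y, A, W1, W2, W3, pi0, tau0, bar_tau0)
}

set.seed(1)
dat <- simulate_dgp(1000)
head(dat)
```

Returning the true `pi0`, `tau0`, and `bar_tau0` alongside the observed data makes it easy to compare estimates against the truth in a simulation.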

Statistical Parameters


CATE

\tau_0(w) = E_0[Y | A=1, W=w] - E_0[Y | A=0, W=w]


V-specific CATE

\bar{\tau}_0(v) = E_0[ \tau_0(W) | V=v]

V-specific CATE CDF

\overline{\theta}_0(t) = E_0[1\{\bar{\tau}_0(V) \leq t\}]
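
For this data-generating process the CDF has a closed form: since V is uniform on (0,1), \sin(2\pi V) follows the arcsine law, which gives \overline{\theta}_0(t) = 1/2 + \arcsin(2t)/\pi for t in [-0.5, 0.5]. The sketch below checks this against a Monte Carlo approximation; the names `theta0_closed` and `theta0_mc` are illustrative.

```r
# Closed-form CDF of bar_tau_0(V) = 0.5 * sin(2*pi*V) with V ~ Unif(0, 1).
# Clamping 2t to [-1, 1] makes the CDF equal 0 below -0.5 and 1 above 0.5.
theta0_closed <- function(t) {
  0.5 + asin(pmin(pmax(2 * t, -1), 1)) / pi
}

# Monte Carlo approximation of E[1{bar_tau_0(V) <= t}]
set.seed(1)
V <- runif(2e5)
bar_tau0 <- 0.5 * sin(2 * pi * V)
t_grid <- c(-0.6, -0.5, -0.25, 0, 0.25, 0.5, 0.6)
theta0_mc <- sapply(t_grid, function(t) mean(bar_tau0 <= t))

# Largest discrepancy over the grid; should be of Monte Carlo order
max_abs_diff <- max(abs(theta0_mc - theta0_closed(t_grid)))
max_abs_diff
```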

V-specific CATE Primitive CDF

\overline{\Psi}_0(t) = E_0[1\{\bar{\tau}_0(V) \leq t\} \{t-\bar{\tau}_0(V)\}]
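
The name "primitive" reflects that \overline{\Psi}_0(t) is an antiderivative of the CDF: by Fubini's theorem, E_0[1\{\bar{\tau}_0(V) \leq t\}\{t - \bar{\tau}_0(V)\}] = \int_{-\infty}^{t} \overline{\theta}_0(s)\,ds. A small numerical sketch of this identity for the present data-generating process follows; the function names are illustrative, and the closed-form CDF used is specific to this simulation.

```r
# CDF of bar_tau_0(V) for this DGP (arcsine law on [-0.5, 0.5])
theta0_closed <- function(t) 0.5 + asin(pmin(pmax(2 * t, -1), 1)) / pi

# Primitive computed as the integral of the CDF
Psi0_integral <- function(t) {
  if (t <= -0.5) return(0)
  # integrate the CDF over the support, then add the region where the CDF equals 1
  inner <- integrate(theta0_closed, lower = -0.5, upper = min(t, 0.5))$value
  inner + max(t - 0.5, 0)
}

# Primitive computed directly from its definition by Monte Carlo
set.seed(1)
bar_tau0 <- 0.5 * sin(2 * pi * runif(2e5))
Psi0_mc <- function(t) mean((t - bar_tau0) * (bar_tau0 <= t))

t_grid <- c(-0.6, -0.25, 0, 0.25, 0.6)
comparison <- cbind(
  t = t_grid,
  integral = sapply(t_grid, Psi0_integral),
  monte_carlo = sapply(t_grid, Psi0_mc)
)
comparison
```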


Remainder Terms for Primitive

\begin{align*}
\Psi_t(P) - \Psi_t(P_0) + P_0 D(P) & = E_{0} \bigg[ 1\{\bar{\tau}_P \leq t < \bar{\tau}_0\} (\bar{\tau}_P - \bar{\tau}_0)\bigg] \\
& \quad + E_{0} \bigg[ 1\{\bar{\tau}_P < t \leq \bar{\tau}_0\} (t-\bar{\tau}_P)\bigg]\\
& \quad - E_{0} \bigg[ 1\{\bar{\tau}_0 < t \leq \bar{\tau}_P\} (t - \bar{\tau}_0) \bigg] \\
& \quad - \frac{1}{2} E_0\left[1\{\bar{\tau}_P = t\} \left( \bar{\tau}_P-\bar{\tau}_0 \right) \right] \\
& \quad + \bar{R}_1(P_0, P) \\
& = \bar{R},
\end{align*}

where we have defined

\begin{align*}
\bar{R}_1(P_0, P) := - E_{0} \bigg[ \left(1\{\bar{\tau}_P(V)<t\} + \frac{1\{\bar{\tau}_P(V) = t\}}{2} \right) \bigg( \left( \frac{\pi_0(W)}{\pi_P(W)} - 1 \right) (\mu_0(1,W) - \mu_P(1,W)) - \left( \frac{1 - \pi_0(W)}{1 - \pi_P(W)} - 1 \right) (\mu_0(0,W) - \mu_P(0,W)) \bigg) \bigg].
\end{align*}


Nuisance Parameter Estimation

| Nuisance Parameter | Algorithm | Purpose / Rationale |
|---|---|---|
| \mu_P(A,W) (outcome regression) | Stratified HAL (Lrnr_stratified + Lrnr_hal9001, max_degree=2, smoothness_orders=1), stratified on A | Flexible nonparametric regression with oracle-rate guarantees; stratification avoids modeling treatment interactions |
| \pi_P(W) (propensity score) | Logistic regression (GLM with logit link), model A \sim W_1 + W_2 + W_3 | Correctly specified parametric model for stability and fast convergence |
| \tau_P(W) (CATE) | Plug-in estimator \tau_n(W)=\mu_{n,1}(W)-\mu_{n,0}(W) | Standard T-learner derived from outcome regression |
| \bar{\tau}_P(V) (V-specific CATE) | DR-learner: regress pseudo-outcome \phi_n on V using kernel regression (safe_fk_regression) | One-dimensional smoothing exploits V=W_1 structure and improves efficiency |
| \mathrm{Var}_P(Y\mid A,W) | Stratified HAL fit to squared residuals (Y-\mu_n(A,W))^2 | Flexible variance estimation needed for Chernoff scaling constants |
| c_P(W) | \displaystyle c_n(W)=\frac{\widehat{\mathrm{Var}}(Y\mid A=1,W)(1-\pi_n(W))+\widehat{\mathrm{Var}}(Y\mid A=0,W)\pi_n(W)}{\pi_n(W)(1-\pi_n(W))} | Efficient influence-function variance component |
| \mathbb{E}[c_P(W)\mid V] | Kernel regression of c_n(W) on V (safe_fk_regression) | Conditional expectation in one dimension for stability |
| \mathrm{Var}(\tau_P(W)\mid V) | Kernel regression of (\tau_n(W)-\bar{\tau}_n(V))^2 on V (safe_fk_regression) | Component of \bar{c}_n(V) decomposition |
| \bar{c}_P(V) | \bar{c}_n(V)=\mathbb{E}[c_n(W)\mid V]+\mathrm{Var}(\tau_n(W)\mid V) | Chernoff scaling variance term |
| \mathbb{E}[\bar{c}_P(V)\mid \bar{\tau}_P(V)] | HAL regression (Lrnr_hal9001, max_degree=1, smoothness_orders=0) of \bar{c}_n(V) on \bar{\tau}_n(V) | Smooth relationship needed to estimate local curvature |
| f_{\bar{\tau}_P(V)} (density) | Kernel density estimation via FKSUM::fk_density | Required for Chernoff normalization constant |
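
To make the DR-learner step for \bar{\tau}_P(V) concrete, here is a minimal sketch under stated assumptions. It uses the true nuisance functions of this data-generating process in place of the fitted \mu_n and \pi_n, and it swaps in `stats::loess` for the report's `safe_fk_regression` helper, whose implementation is not shown here. The pseudo-outcome is the usual doubly robust form \phi = \mu(1,W) - \mu(0,W) + \frac{A - \pi(W)}{\pi(W)(1-\pi(W))}\{Y - \mu(A,W)\}.

```r
set.seed(1)
n <- 5000

# Simulate from the DGP described earlier in this report
W1 <- runif(n); W2 <- runif(n, -1, 1); W3 <- runif(n, -5, 5)
V <- W1
pi0 <- plogis(-0.75 * W1 + 0.5 * W2 - 0.25 * W3)
A <- rbinom(n, 1, pi0)
tau0 <- 0.5 * sin(2 * pi * W1) + 0.5 * cos(pi * W2)
mu0 <- -2 + 0.5 * W1 - 0.25 * W2 + 0.1 * W3
mu1 <- mu0 + tau0
Y <- ifelse(A == 1, mu1, mu0) + rnorm(n, 0, 0.2)

# ASSUMPTION: oracle nuisances stand in for the fitted mu_n and pi_n
mu_A <- ifelse(A == 1, mu1, mu0)

# Doubly robust pseudo-outcome
phi <- (mu1 - mu0) + (A - pi0) / (pi0 * (1 - pi0)) * (Y - mu_A)

# ASSUMPTION: loess replaces the kernel smoother used in the report
fit <- loess(phi ~ V, span = 0.3, degree = 2)

# Compare the smoothed pseudo-outcome with the true V-specific CATE
v_grid <- seq(0.1, 0.9, by = 0.1)
bar_tau_hat <- predict(fit, newdata = data.frame(V = v_grid))
max_err <- max(abs(bar_tau_hat - 0.5 * sin(2 * pi * v_grid)))
max_err
```

In practice the nuisances would be estimated, ideally with cross-fitting, and the span would be tuned; both choices here are fixed only to keep the sketch short.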
f_{\overline{\tau}_0}(x)

\overline{c}_0(V) = E_0[ \tfrac{Var(Y | A=1,W)(1-\pi_0(W)) + Var(Y | A=0,W)\pi_0(W)}{\pi_0(W)(1-\pi_0(W))} | V] + Var(\tau_0(W) | V)
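
For this particular data-generating process, both components of \overline{c}_0(V) have closed forms. A sketch of the calculation, assuming homoskedastic noise with variance \sigma^2 and mutually independent covariates as in the simulation code (\eta and K are shorthand introduced here):

\begin{align*}
c_0(W) &= \sigma^2\left\{\frac{1}{\pi_0(W)} + \frac{1}{1-\pi_0(W)}\right\} = \sigma^2\{2 + 2\cosh(\eta(W))\}, \qquad \eta(W) = \gamma_0 + \gamma_1 W_1 + \gamma_2 W_2 + \gamma_3 W_3, \\
E_0[c_0(W) \mid V = v] &= \sigma^2\{2 + 2K\cosh(\gamma_0 + \gamma_1 v)\}, \qquad K = \frac{\sinh(\gamma_2)}{\gamma_2}\cdot\frac{\sinh(5\gamma_3)}{5\gamma_3}, \\
Var(\tau_0(W) \mid V) &= \tfrac{1}{4}\,Var(\cos(\pi W_2)) = \tfrac{1}{4}\,E[\cos^2(\pi W_2)] = \tfrac{1}{8}.
\end{align*}

The first line uses 1/\pi_0 = 1 + e^{-\eta} and 1/(1-\pi_0) = 1 + e^{\eta}. The second uses E[\cosh(a + bX)] = \cosh(a)\sinh(bu)/(bu) for X \sim \mathrm{Unif}(-u, u), applied once for W_2 and once for W_3.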

E[ \overline{c}_0(V) | \overline{\tau}_0(V)]

\rho_0(t) = (4 f_{\bar{\tau}_0}(t)^2 E[ \overline{c}_0(V) | \overline{\tau}_0(V) = t])^{1/3}
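
As a rough illustration, the true scaling constant \rho_0(t) can be evaluated in closed form for this data-generating process. The sketch below assumes homoskedastic noise and uses the fact that each t in (-0.5, 0.5) corresponds to two equally weighted values of V; all function names are illustrative.

```r
gamma <- c(0, -0.75, 0.5, -0.25)  # propensity coefficients (gamma_0, ..., gamma_3)
sigma <- 0.2                      # outcome noise SD

# sinh(x) / x with the removable singularity at 0 handled explicitly
sinhc <- function(x) ifelse(abs(x) < 1e-12, 1, sinh(x) / x)
K <- sinhc(gamma[3]) * sinhc(5 * gamma[4])

# bar_c_0(v) = E[c_0(W) | V = v] + Var(tau_0(W) | V = v), the latter being 1/8
bar_c0 <- function(v) sigma^2 * (2 + 2 * K * cosh(gamma[1] + gamma[2] * v)) + 0.125

# Density of bar_tau_0(V) = 0.5 * sin(2*pi*V): arcsine law on (-0.5, 0.5)
f0 <- function(t) ifelse(abs(t) < 0.5, 2 / (pi * sqrt(pmax(1 - 4 * t^2, 0))), 0)

# E[bar_c_0(V) | bar_tau_0(V) = t]: average over the two roots of 0.5*sin(2*pi*v) = t
E_barc0_given_t <- function(t) {
  theta <- asin(pmin(pmax(2 * t, -1), 1))
  v1 <- (theta / (2 * pi)) %% 1
  v2 <- ((pi - theta) / (2 * pi)) %% 1
  0.5 * (bar_c0(v1) + bar_c0(v2))
}

# rho_0(t) = (4 * f(t)^2 * E[bar_c_0(V) | bar_tau_0(V) = t])^(1/3)
rho0 <- function(t) (4 * f0(t)^2 * E_barc0_given_t(t))^(1/3)

rho0(c(-0.25, 0, 0.25))
```

Outside the support of \bar{\tau}_0(V) the density is zero, so `rho0` returns 0 there; near the boundary points t = \pm 0.5 the density diverges and the constant becomes large.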


Appendix

Code

# setup
knitr::opts_chunk$set(echo = FALSE, warning = FALSE, message = FALSE, cache = FALSE)

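# load libraries (install missing packages on first use)
if (!requireNamespace("remotes", quietly = TRUE)) install.packages("remotes")
if (!requireNamespace("sl3", quietly = TRUE)) remotes::install_github("tlverse/sl3", force = TRUE)
if (!requireNamespace("hal9001", quietly = TRUE)) remotes::install_github("tlverse/hal9001", force = TRUE)
if (!requireNamespace("pacman", quietly = TRUE)) install.packages("pacman")
library(pacman)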
pacman::p_load(
  tidyverse, kableExtra, knitr,
  grf, remotes, gbm,
  data.table, origami,
  ranger, xgboost, randomForest,
  gt, latex2exp, rsample,
  hal9001, sl3, plotly
)

set.seed(2025)
# results_psi <- readRDS("/Users/nolan/Library/CloudStorage/OneDrive-UW/0_Research/Carone/primitive_cate/reports/chapters/2_simple_example/6.5_vcate_wave_psi.rds")
# results_theta <- readRDS("/Users/nolan/Library/CloudStorage/OneDrive-UW/0_Research/Carone/primitive_cate/reports/chapters/2_simple_example/6.5_vcate_wave_theta.rds")
results_theta <- results_psi <- readRDS("/Users/nolan/Library/CloudStorage/OneDrive-UW/0_Research/Carone/primitive_cate/reports/chapters/2_simple_example/2.1_compare_grids.rds")

fast_qchern <- readRDS("/Users/nolan/Library/CloudStorage/OneDrive-UW/0_Research/Carone/primitive_cate/reports/chapters/2_simple_example/fast_qchern.rds")

qchernoff_fast <- function(p) {
  if (any(p < 0.001 | p > 0.999, na.rm = TRUE)) {
    stop("qchernoff_fast only defined for p in [0.001, 0.999]")
  }
  fast_qchern(p)
}
attr(qchernoff_fast, "description") <- "Approximate Chernoff quantile: qchernoff_fast(p) ~ ChernoffDist::qchernoff(p)"

lower <- -0.5
upper <- 0.5
t_subset <- c(lower-0.1, seq(lower, upper, 0.25), upper+0.1) |> sort()

t_theta_vals <- t_subset # results_theta$t |> unique()
# qs_theta <- quantile(t_theta_vals, probs = c(0.15, 0.5, 0.85))
t_theta_subset <- t_theta_vals # sapply(qs_theta, \(q) t_theta_vals[which.min(abs(t_theta_vals - q))])

t_psi_vals <- t_theta_vals # results_psi$t |> unique()
# qs_psi <- quantile(t_psi_vals, probs = c(0.15, 0.5, 0.85))
t_psi_subset <- t_psi_vals # sapply(qs_psi, \(q) t_psi_vals[which.min(abs(t_psi_vals - q))])
##########
## DGM  ##
##########

# propensity model
gamma0 <- 0
gamma1 <- -0.75
gamma2 <-  0.5
gamma3 <- -0.25

# outcome noise SD
sigma <- 0.2

param_tvals <- c(seq(lower - 0.1, upper + 0.1, 0.01), t_theta_vals) |>
  unique() |>
  sort()
param_n <- 100000 # for parameter computation

#############################
## Data Generating Process ##
#############################

## generate data ##
W1 <- runif(param_n, 0, 1)
W2 <- runif(param_n, -1, 1)
W3 <- runif(param_n, -5, 5)
W  <- cbind("W1" = W1, "W2" = W2, "W3" = W3)
V  <- W1

# Propensity score
pi0 <- plogis(gamma0 + gamma1 * W1 + gamma2 * W2 + gamma3 * W3)
A   <- rbinom(param_n, 1, pi0)

# TRUE CATE and baseline, written inline:
# tau_0(W) = 0.5 * sin(2*pi*W1) + 0.5 * cos(pi*W2)
tau0 <- 0.5 * sin(2 * pi * W1) + 0.5 * cos(pi * W2)

# mu_00(W) = -2 + 0.5*W1 - 0.25*W2 + 0.1*W3
mu_00 <- -2 + 0.5 * W1 - 0.25 * W2 + 0.1 * W3     # E[Y | A=0, W]
mu_01 <- mu_00 + tau0                             # E[Y | A=1, W]
mu_0  <- mu_00 + A * tau0                         # E[Y | A, W]

# TRUE V-specific CATE
bar_tau0 <- 0.5 * sin(2 * pi * V)

# Observed outcome
Y <- mu_0 + rnorm(param_n, mean = 0, sd = sigma)
YAW <- data.frame(Y = Y, A = A, W)
Y1W <- data.frame(Y = Y, A = 1, W)
Y0W <- data.frame(Y = Y, A = 0, W)

############################
## Statistical Parameters ##
############################

# Var_0(Y \mid A, W)
# YAW$resid2 <- (Y - mu_0)^2
# var0Y1W_fit <- lm(resid2 ~ W1 + W2 + W3, data = YAW |> filter(A==1))
# var0Y0W_fit <- lm(resid2 ~ W1 + W2 + W3, data = YAW |> filter(A==0))
# var0Y1W <- predict(var0Y1W_fit, newdata = data.frame(W1, W2, W3))
# var0Y0W <- predict(var0Y0W_fit, newdata = data.frame(W1, W2, W3))
# Under homoskedasticity
var0Y1W <- sigma^2
var0Y0W <- sigma^2

# c_0(W) := \frac{Var(Y | A=1,W)(1-\pi_0(W)) + Var(Y | A=0,W)\pi_0(W)}{\pi_0(W)(1-\pi_0(W))}
c0W <- (var0Y1W*(1-pi0) + var0Y0W*pi0) / (pi0 * (1-pi0))

# E_0[ c_0(W) | V]
# Derivation Note:
# 1. Under homoskedasticity (Var = sigma^2), the numerator of c_0(W) is sigma^2, so
#    c_0(W) = sigma^2 / (pi_0 * (1 - pi_0)) = sigma^2 * [ 1/pi_0 + 1/(1-pi_0) ]
#    Since pi_0 = expit(eta) with eta = gamma0 + gamma1*W1 + gamma2*W2 + gamma3*W3,
#    1/pi_0 + 1/(1-pi_0) = (1 + exp(-eta)) + (1 + exp(eta)) = 2 + 2 * cosh(eta)
# 2. To find E[c_0(W) | V=v], we average cosh(eta) over W2 ~ U(-1, 1) and W3 ~ U(-5, 5).
# 3. The mean of cosh(a + bX) for X ~ U(L, U) is:
#    [sinh(a + bU) - sinh(a + bL)] / (b * (U - L))
#    For symmetric bounds (L = -U) this equals cosh(a) * sinh(bU) / (bU),
#    which yields the sinh(x)/x terms collected in K below.
E_c0_V_closed <- function(v) {
  
  sinh_over_x <- function(x) ifelse(abs(x) < 1e-12, 1, sinh(x) / x)
  
  a <- gamma0 + gamma1 * v
  K <- sinh_over_x(gamma2) * sinh_over_x(5 * gamma3)
  sigma^2 * (2 + 2 * K * cosh(a))
}
E_c0_V_closed <- Vectorize(E_c0_V_closed, vectorize.args = "v")
E_c0_V <- E_c0_V_closed(V) # FKSUM::fk_regression(V, c0W, h = "cv", type = 'loc-lin')

# Var(\tau_0(W) | V)
# Under this DGM, E[cos(pi W_2)] = 0, so Var(\tau_0(W) | V) = 0.25 * E[cos^2(pi W_2)] = 0.25 * 0.5
var0_tau0_V <- 0.125 # FKSUM::fk_regression(V, (tau0 - bar_tau0)^2, ngrid = 5000)

# \bar{c}_0 := E_0[ c_0(W) | V] + Var(\tau_0(W) | V)
barc0V <- E_c0_V + var0_tau0_V

# E[\bar{c}_0(V) \mid \bar{\tau}_0(V)]
E_barc0_bar_tau0 <- FKSUM::fk_regression(bar_tau0, barc0V, h = "cv", type = 'loc-lin')
E_barc0_given_bartau0 <- function(t) {
  # 1. Map bartau0 back to the two possible roots of V in [0, 1]
  # bartau0 = 0.5 * sin(2 * pi * V) => sin(2 * pi * V) = 2t
  # We use pmin/pmax to prevent NaN from tiny floating point overflows
  theta <- asin(pmin(pmax(2 * t, -1), 1))
  
  v1 <- (theta / (2 * pi)) %% 1
  v2 <- ((pi - theta) / (2 * pi)) %% 1
  
  # 2. Compute bar_c0(V) for both roots
  # Uses your existing global: E_c0_V_closed(v) and var0_tau0_V
  c_v1 <- E_c0_V_closed(v1) + var0_tau0_V
  c_v2 <- E_c0_V_closed(v2) + var0_tau0_V
  
  # 3. Conditional expectation is the arithmetic mean of the two branches
  return(0.5 * (c_v1 + c_v2))
}
E_barc0_given_bartau0 <- Vectorize(E_barc0_given_bartau0, vectorize.args = "t")
# E[\bar{c}_0(V) | \bar{\tau}_0(V) = t] evaluated on param_tvals using the closed form above
E_barc0_bar_tau0_t_vals <- E_barc0_given_bartau0(param_tvals)

# True V-specific CATE density at t over param_tvals
f_0bar_tau0_t_vals <- ifelse(
  abs(param_tvals) < 0.5,
  2 / (pi * sqrt(1 - 4 * param_tvals^2)),
  0
)

# True V-specific CATE CDF at t over param_tvals
bar_theta0_t_vals <- ifelse(
  param_tvals < -0.5,
  0,
  ifelse(
    param_tvals > 0.5,
    1,
    0.5 + (1 / pi) * asin(2 * param_tvals)
  )
)

parameter_df <- tibble(
  t = param_tvals
) %>%
  mutate(
    # \bar{\theta}_0 (CDF of V-specific CATE)
    bar_theta0 = bar_theta0_t_vals,
    
    # \bar{\Psi}_t(P_0) (V-specific CATE primitive)
    bar_Psi_t_P0 = map_dbl(t, ~ mean((.x - bar_tau0) * (bar_tau0 <= .x))),
    
    # \bar{D}_t(P_0) (true SD of IF)
    SD_bar_D_t0 = map_dbl(
      t,
      ~ sd(
        -((bar_tau0 < .x) + 0.5 * (bar_tau0 == .x)) *
          ((A - pi0) / (pi0 * (1 - pi0)) * (Y - mu_0) + tau0 - bar_tau0) +
          (bar_tau0 <= .x) * (.x - bar_tau0) -
          mean((.x - bar_tau0) * (bar_tau0 <= .x))
      )
    ),
    
    # V-specific CATE density at t (already evaluated on param_tvals)
    f_0bar_tau0_t = f_0bar_tau0_t_vals,
    
    # E[\bar{c}_0(W) | \bar{\tau}_0 = t] on param_tvals
    E_barc0_bar_tau0_t = E_barc0_bar_tau0_t_vals,
    
    # \bar{\kappa}_0(t)
    bar_kappa0 = E_barc0_bar_tau0_t * f_0bar_tau0_t,
    
    # Chernoff scaling constant: \bar{\rho}_0(t)
    bar_rho0 = (4 * f_0bar_tau0_t * bar_kappa0)^(1/3)
  )


# Grid over W1 in [0,1] and W2 in [-1,1]
W1_grid <- seq(0, 1, length.out = 51)
W2_grid <- seq(-1, 1, length.out = 51)

# Create matrix of tau_0(W1, W2) values
# tau_0(W) = 0.5 * sin(2*pi*W1) + 0.5 * cos(pi*W2)
tau_mat <- outer(
  W1_grid,
  W2_grid,
  function(w1, w2) 0.5 * sin(2 * pi * w1) + 0.5 * cos(pi * w2)
)

## 3D surface plot with sign-based coloring

plot_ly() |>
  add_surface(
    x = ~W1_grid,
    y = ~W2_grid,
    # plotly maps rows of z to y and columns to x, so transpose the W1-by-W2 matrix
    z = ~t(tau_mat),
    colorscale = list(
      c(0.0, "red"),      # adverse effect (red)
      c(0.5, "#F0F0F0"),  # around zero (light gray)
      c(1.0, "blue")      # positive effect (blue)
    )
  ) |>
  layout(
    scene = list(
      xaxis = list(title = "W1"),
      yaxis = list(title = "W2"),
      zaxis = list(title = latex2exp::TeX("$\\tau_0(W)$"))
    ),
    title = "True CATE surface: tau_0(W) = 0.5 sin(2*pi*W1) + 0.5 cos(pi*W2)"
  )
## Heatmap with the same sign-based color palette

plot_ly(
  x = ~W1_grid,
  y = ~W2_grid,
  # transpose so that rows of z index y (W2) and columns index x (W1)
  z = ~t(tau_mat)
) |>
  add_heatmap(
    colorscale = list(
      c(0.0, "red"),      # negative (red)
      c(0.5, "#F0F0F0"),  # around zero (light gray)
      c(1.0, "blue")      # positive (blue)
    )
  ) |>
  layout(
    xaxis = list(title = "W1"),
    yaxis = list(title = "W2"),
    title = "True CATE heatmap: tau_0(W)"
  )
data.frame(
  V = V,
  bar_tau0 = bar_tau0
) %>%
  ggplot(aes(x = V, y = bar_tau0)) +
  geom_line() +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  labs(
    x = expression(V == W[1]),
    y = expression(bar(tau)[0](V)),
    title = expression("True V-specific CATE " ~ bar(tau)[0](V))
  ) +
  theme_bw() +
  theme(aspect.ratio = 1)



ggplot(parameter_df, aes(x=t, y = bar_theta0)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$\\bar{\\theta}_P(t)$"),
    title = "CDF of V-specific CATE",
    subtitle = latex2exp::TeX("$P(\\bar{\\tau}_P(V) \\leq t)$")
  ) +
  theme(aspect.ratio = 1)

results_theta_bias <- results_theta %>%
  select(
    t, n, grid,
    bar_theta_n, bar_theta_os, # bar_theta_tmle,
    E_barcn_bar_taun_t, f_nbar_taun, # bar_kappan,
    bar_rhon
  ) %>%
  left_join(
    parameter_df,
    by = c("t")
  ) %>%
  group_by(t, n, grid) %>%
  summarise(
    across(where(is.numeric), ~ mean(.x, na.rm = TRUE)),
    .groups = "drop"
  ) %>%
  mutate(
    bias_n = bar_theta_n - bar_theta0,
    bias_os = bar_theta_os - bar_theta0,
    # bias_tmle = bar_theta_tmle - bar_theta0,
    cbrt_n_bias_n    = n^(1/3) * bias_n,
    cbrt_n_bias_os    = n^(1/3) * bias_os,
    # cbrt_n_bias_tmle  = n^(1/3) * bias_tmle,
    # chernoff constant parameters
    bias_E_hatc = E_barcn_bar_taun_t - E_barc0_bar_tau0_t,
    bias_f_bartau = f_nbar_taun - f_0bar_tau0_t,
    # bias_bar_kappa = bar_kappan - bar_kappa0,
    bias_bar_rho = bar_rhon - bar_rho0
  )

# \bar{\theta}_n - \bar{\theta}(P_0)
results_theta_bias %>%
  select(
    t, n, grid,
    bias_n, bias_os, # bias_tmle
  ) %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = bias_n, color = "Plug-in", linetype = "Plug-in")) +
  geom_line(aes(y = bias_os, color = "One-step", linetype = "One-step")) +
  # geom_line(aes(y = bias_tmle, color = "TMLE", linetype = "TMLE")) +
  facet_grid(t ~ grid) +
  theme_bw() +
  geom_hline(yintercept = 0) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  labs(
    x = "Sample size (n, log scale)",
    y = "Mean Bias",
    color = "Estimator",
    linetype = "Estimator",
    title = "Mean Bias of V-specific CATE CDF by t"
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2"
    # "TMLE" = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid"
    # "TMLE" = "longdash"
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
results_theta_bias %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(
    y = cbrt_n_bias_n,
    color = "Plug-in",
    linetype = "Plug-in"),
    linewidth = 1
  ) +
  geom_point(aes(
    y = cbrt_n_bias_n,
    color = "Plug-in",
    shape = "Plug-in"),
    size = 2
  ) +
  geom_line(aes(
    y = cbrt_n_bias_os,
    color = "One-step",
    linetype = "One-step"),
    linewidth = 1
  ) +
  geom_point(aes(
    y = cbrt_n_bias_os,
    color = "One-step",
    shape = "One-step"),
    size = 2
  ) +
  # geom_line(aes(
  #   y = cbrt_n_bias_tmle,
  #   color = "TMLE",
  #   linetype = "TMLE"),
  #   linewidth = 1
  # ) +
  facet_grid(grid ~ t, scales = "free_y") +
  theme_bw() +
  geom_hline(yintercept = 0) +
  labs(
    x = "Sample Size (n)",
    y = latex2exp::TeX("$n^{1/3} ( \\frac{1}{NSIM} \\sum_i \\theta_{t,n,i} - \\theta_{t,0} )$"),
    color = "Estimator",
    linetype = "Estimator",
    shape = "Estimator",
    title = latex2exp::TeX("$n^{1/3}$-Scaled Mean Bias of V-specific CATE CDF")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE" = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE" = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
# compute Monte Carlo SEs by (t, n)
results_theta_se <- results_theta %>%
  group_by(t, n, grid) %>%
  summarise(
    se_Pn   = sd(bar_theta_n, na.rm = TRUE),
    se_os   = sd(bar_theta_os, na.rm = TRUE)
    # se_tmle = sd(bar_theta_tmle, na.rm = TRUE)
  ) %>%
  mutate(
    cbrt_n_se_Pn   = n^(1/3) * se_Pn,
    cbrt_n_se_os   = n^(1/3) * se_os
    # cbrt_n_se_tmle = n^(1/3) * se_tmle
  ) %>%
  ungroup()

# plot unscaled Monte Carlo SE (the n^(1/3)-scaled version is plotted in the next chunk)
results_theta_se %>%
  filter(
    near(t, t_theta_subset[1]) | near(t, t_theta_subset[2]) | near(t, t_theta_subset[3])
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = se_Pn, color = "Plug-in", linetype = "Plug-in")) +
  geom_line(aes(y = se_os, color = "One-step", linetype = "One-step")) +
  # geom_line(aes(y = se_tmle, color = "TMLE", linetype = "TMLE")) +
  facet_grid(grid ~ t) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  theme_bw() +
  labs(
    x = "Sample size (n)",
    y = latex2exp::TeX("Monte Carlo SE"),
    color = "Estimator",
    linetype = "Estimator",
    title = latex2exp::TeX("Monte Carlo SE of $\\bar{\\theta}_t(P_n)$ by t")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2"
    # "TMLE" = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid"
    # "TMLE" = "longdash"
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
# compute Monte Carlo SEs by (t, n)
results_theta_scaled_se <- results_theta %>%
  group_by(t, n, grid) %>%
  summarise(
    se_Pn   = sd(bar_theta_n, na.rm = TRUE),
    se_os   = sd(bar_theta_os, na.rm = TRUE),
    # se_tmle = sd(bar_theta_tmle, na.rm = TRUE)
  ) %>%
  mutate(
    cbrt_n_se_Pn   = n^(1/3) * se_Pn,
    cbrt_n_se_os   = n^(1/3) * se_os,
    # cbrt_n_se_tmle = n^(1/3) * se_tmle
  ) %>%
  ungroup()

# plot n^(1/3)-scaled Monte Carlo SE
results_theta_scaled_se %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(
    aes(y = cbrt_n_se_Pn, color = "Plug-in", linetype = "Plug-in"),
    linewidth = 1
  ) +
  geom_point(
    aes(y = cbrt_n_se_Pn, color = "Plug-in", shape = "Plug-in"),
    size = 2
  ) +
  geom_line(
    aes(y = cbrt_n_se_os, color = "One-step", linetype = "One-step"),
    linewidth = 1
  ) +
  geom_point(
    aes(y = cbrt_n_se_os, color = "One-step", shape = "One-step"),
    size = 2
  ) +
  # geom_line(
  #   aes(y = cbrt_n_se_tmle, color = "TMLE", linetype = "TMLE"),
  #   linewidth = 1
  # ) +
  facet_grid(grid ~ t) +
  geom_hline(
    yintercept = 0,
    linetype = "dashed"
  ) +
  theme_bw() +
  labs(
    x = "Sample size (n)",
    y = latex2exp::TeX("$n^{1/3} \\cdot$ Monte Carlo SE"),
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title = latex2exp::TeX("$n^{1/3}$-scaled Monte Carlo SE of $\\bar{\\theta}_t(P_n)$")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE" = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE" = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

# Compute n^(2/3)-scaled MSE for each estimator (TMLE currently disabled)
results_theta_mse <- results_theta %>%
  select(
    t, n, grid,
    bar_theta_n, bar_theta_os, # bar_theta_tmle
  ) %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  pivot_longer(
    cols = contains("bar_theta_"),
    names_to = "estimator",
    values_to = "bar_theta_est"
  ) %>%
  mutate(
    sq_err = (bar_theta_est - bar_theta0)^2,
  ) %>%
  group_by(t, n, grid, estimator) %>%
  summarise(mse = mean(sq_err, na.rm = TRUE), .groups="drop") %>%
  mutate(sc_mse = n^(2/3) * mse) %>%
  mutate(
    estimator = recode(
      estimator,
      bar_theta_n    = "Plug-in",
      bar_theta_os   = "One-step"
      # bar_theta_tmle = "TMLE"
    )
  )

results_theta_mse %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(
    x = n, y = sc_mse,
    color = estimator, shape = estimator, linetype = estimator
  )) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0, color = "black") +
  facet_grid(grid ~ t, scales = "free_y") +
  theme_bw() +
  labs(
    x = "Sample Size (n)",
    y = latex2exp::TeX("$n^{2/3} \\cdot$ MSE"),
    color = "Estimator",
    linetype = "Estimator",
    shape = "Estimator",
    title = latex2exp::TeX("Sample Size Scaled Mean Squared Error")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE"    = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE"    = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

results_theta_coverage <- results_theta %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  select(
    t, n, grid,
    bar_theta0,
    bar_theta_n, bar_theta_os, #bar_theta_tmle,
    bar_rhon, bar_rhon_plugin
  ) %>%
  # Long format for the three estimators' point estimates
  pivot_longer(
    cols         = contains("bar_theta_"),
    names_to     = "estimator",
    names_prefix = "bar_theta_",
    values_to    = "theta_hat"
  ) %>%
  mutate(
    # use the plug-in rho for the plug-in estimator; other estimators keep bar_rhon
    bar_rhon = if_else(estimator == "n", bar_rhon_plugin, bar_rhon),
    
    # Chernoff-scale SE: rho_n / n^(1/3), using the estimator-specific rho selected above
    se_chernoff = bar_rhon / n^(1/3),
    
    # 97.5% Chernoff quantile with the factor 2 built in
    q975        = fast_qchern(0.975),
    
    # Wald-type Chernoff CIs
    ci_lower    = theta_hat - q975 * se_chernoff,
    ci_upper    = theta_hat + q975 * se_chernoff,
    
    # Coverage indicator
    covered     = (ci_lower <= bar_theta0 & bar_theta0 <= ci_upper),
    
    # Nice estimator labels for plotting
    estimator   = dplyr::recode(
      estimator,
      "n"    = "Plug-in",
      "os"   = "One-step",
      "tmle" = "TMLE"
    )
  ) %>%
  group_by(n, t, grid, estimator) %>%
  summarise(
    coverage = mean(covered, na.rm = TRUE),
    .groups  = "drop"
  )

results_theta_coverage %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(
    .,
    aes(
      x        = n,
      y        = coverage,
      color    = estimator,
      linetype = estimator,
      shape    = estimator
    )
  ) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0.95, linetype = "dashed", color = "black") +
  facet_grid(grid ~ t) +
  theme_bw() +
  labs(
    x        = "Sample Size",
    y        = "Empirical 95% Coverage",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Empirical Coverage of 95\\% Confidence Intervals for $\\bar{\\theta}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE"    = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE"    = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

results_theta_coverage %>%
  ggplot(aes(
    x = t, y = coverage,
    color    = estimator,
    linetype = estimator,
    shape    = estimator
  )) +
  geom_line(linewidth = 1) +
  geom_hline(yintercept = 0.95) +
  facet_grid(~ n) +
  scale_x_continuous(
    limits = c(min(t_theta_vals), max(t_theta_vals)),
    breaks = seq(min(t_theta_vals), max(t_theta_vals)+0.2, by = 0.2)
  ) +
  labs(
    x        = "t",
    y        = "Empirical 95% Coverage",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Empirical Coverage of 95\\% Confidence Intervals for $\\bar{\\theta}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE"    = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE"    = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme_bw() +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
results_theta_ciwidth <- results_theta %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  select(
    t, n, grid,
    bar_theta0,
    bar_theta_n, bar_theta_os, # bar_theta_tmle,
    bar_rhon, bar_rhon_plugin
  ) %>%
  pivot_longer(
    cols         = contains("bar_theta_"),
    names_to     = "estimator",
    names_prefix = "bar_theta_",
    values_to    = "theta_hat"
  ) %>%
  mutate(
    # use plug-in rho for plug-in
    bar_rhon = if_else(estimator == "n", bar_rhon_plugin, bar_rhon),
    
    # Chernoff SE
    se_chernoff = bar_rhon / n^(1/3),
    
    # 97.5% Chernoff quantile with factor 2 built in (your convention)
    q975     = fast_qchern(0.975),
    
    # Wald-type Chernoff CIs
    ci_lower = theta_hat - q975 * se_chernoff,
    ci_upper = theta_hat + q975 * se_chernoff,
    
    # CI width
    ci_width = ci_upper - ci_lower,
    
    # nice labels
    estimator = dplyr::recode(
      estimator,
      "n"    = "Plug-in",
      "os"   = "One-step",
      "tmle" = "TMLE"
    )
  ) %>%
  group_by(n, t, grid, estimator) %>%
  summarise(
    mean_ci_width   = mean(ci_width, na.rm = TRUE),
    median_ci_width = median(ci_width, na.rm = TRUE),
    .groups = "drop"
  )

results_theta_ciwidth %>%
  filter(t %in% t_theta_subset) %>%
  ggplot(
    aes(
      x        = n,
      y        = mean_ci_width,   # swap to median_ci_width if you prefer
      color    = estimator,
      linetype = estimator,
      shape    = estimator
    )
  ) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0, color = "black") +
  facet_grid(grid ~ t) +
  theme_bw() +
  labs(
    x        = "Sample Size",
    y        = "Mean 95% CI Width",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Mean Width of 95\\% Confidence Intervals for $\\bar{\\theta}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in"  = "plum",
    "One-step" = "#0072B2",
    "TMLE"     = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in"  = "dotted",
    "One-step" = "solid",
    "TMLE"     = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in"  = 16,
    "One-step" = 17,
    "TMLE"     = 15
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

# Compute n^(1/3)-standardized quantities
results_theta_standardized <- results_theta %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  mutate(
    z_Pn   = n^(1/3) * (bar_theta_n   - bar_theta0) / bar_rhon_plugin,
    z_os   = n^(1/3) * (bar_theta_os - bar_theta0) / bar_rhon,
    # z_tmle = n^(1/3) * (bar_theta_tmle - bar_theta0) / bar_rhon
  ) %>%
  select(t, n, grid, starts_with("z_")) %>%
  pivot_longer(
    cols = starts_with("z_"),
    names_to = "estimator",
    values_to = "z_value"
  ) %>%
  mutate(
    estimator = recode(estimator,
                       "z_Pn"   = "Plug-in",
                       "z_os"   = "One-step",
                       "z_tmle" = "TMLE"
    )
  )

# Standard Chernoff reference density
chern_ref <- tibble(
  x = seq(-4, 4, length.out = 400),
  density = ChernoffDist::dChern(seq(-4, 4, length.out = 400))
)


dist_plot_for_grid <- function(dat, grid_val, chern_ref) {
  dat %>%
    filter(grid == grid_val) %>%
    ggplot(aes(x = z_value, fill = n)) +
    geom_histogram(
      aes(y = after_stat(density)),
      position = "identity",
      alpha = 0.4,
      color = "black",
      bins = 40
    ) +
    geom_line(
      data = chern_ref,
      aes(x = x, y = density),
      color = "black",
      linewidth = 1,
      linetype = "solid",
      inherit.aes = FALSE
    ) +
    geom_vline(xintercept = 0, linetype = "dashed") +
    facet_grid(estimator ~ t, scales = "free") +
    theme_bw() +
    labs(
      x = latex2exp::TeX("$n^{1/3} (\\bar{\\theta}_{t,n} - \\bar{\\theta}_{t,0}) / \\rho_n$"),
      y = "Density",
      fill = "Sample size (n)",
      title = paste("Empirical Distributions with Standard Chernoff Overlay — grid:", grid_val),
      subtitle = latex2exp::TeX("Standardized Estimators of $\\bar{\\theta}_{t,0}$")
    ) +
    scale_fill_brewer(palette = "Set2") +
    theme(
      strip.text = element_text(size = 10),
      aspect.ratio = 1,
      legend.position = "bottom"
    )
}

# Make one plot per grid ---
grid_vals <- sort(unique(results_theta_standardized$grid))
names(grid_vals) <- as.character(grid_vals)

subset_dist <- results_theta_standardized %>%
  filter(
    n %in% 10000,
    t %in% c(-0.25, 0, 0.25)
  ) %>%
  mutate(n = factor(n))   # <-- key

dist_plots <- map(
  grid_vals,
  ~ dist_plot_for_grid(subset_dist, .x, chern_ref)
)
# Print them all
walk(dist_plots, print)
results_theta_qq <- results_theta %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  mutate(
    # Chernoff statistics:
    chernoff_stat_n    = n^(1/3) * (bar_theta_n - bar_theta0) / bar_rhon_plugin,
    chernoff_stat_os   = n^(1/3) * (bar_theta_os - bar_theta0) / bar_rhon,   
    # chernoff_stat_tmle = n^(1/3) * (bar_theta_tmle - bar_theta0) / bar_rhon
  ) %>%
  select(t, n, grid, contains("chernoff_stat_")) %>%
  drop_na() %>%
  pivot_longer(
    contains("chernoff_stat_"), #c(chernoff_stat_n, chernoff_stat_os, chernoff_stat_tmle),
    names_to  = "estimator",
    values_to = "sample"
  ) %>%
  group_by(n, t, grid, estimator) %>%
  arrange(sample, .by_group = TRUE) %>%          # sort empirical quantiles
  mutate(
    prob_seq = ppoints(dplyr::n()),              # ppoints per (n, t, grid, estimator)
    theor    = fast_qchern(prob_seq)             # standardized chernoff
  ) %>%
  ungroup()


# Plot function for one grid value
qq_plot_for_grid <- function(dat, grid_val) {
  dat %>%
    filter(grid == grid_val) %>%
    ggplot(aes(x = theor, y = sample, color = estimator)) +
    geom_point(alpha = 0.5, size = 0.8) +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
    facet_grid(n ~ t, scales = "free_y") +
    labs(
      x = "Theoretical Quantiles",
      y = "Empirical Quantiles",
      title = paste("Chernoff QQ plots — grid:", grid_val)
    ) +
    scale_color_discrete(
      name   = "Estimator",
      breaks = c("chernoff_stat_n", "chernoff_stat_os", "chernoff_stat_tmle"),
      labels = c("Plug-in", "One-step", "TMLE")
    ) +
    theme_bw() +
    theme(
      aspect.ratio    = 1,
      legend.position = "bottom"
    )
}

# Make one QQ plot per grid value
grid_vals <- sort(unique(results_theta_qq$grid))
names(grid_vals) <- as.character(grid_vals)
qq_plots <- map(grid_vals, ~ qq_plot_for_grid(
  results_theta_qq %>%
    filter(
      n %in% c(250, 2500, 10000),
      t %in% t_theta_subset
    ),
  .x))

# Print them all
walk(qq_plots, print)


ggplot(parameter_df, aes(x=t, y = bar_Psi_t_P0)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$\\bar{\\Psi}_0(t)$"),
    title = "Primitive of the V-specific CATE CDF",
    subtitle = latex2exp::TeX("$E_0[1\\{\\bar{\\tau}_0(V) \\leq t\\} \\{t - \\bar{\\tau}_0(V)\\} ]$")
  ) +
  theme(aspect.ratio = 1)

results_psi_bias <- results_psi %>%
  left_join(
    parameter_df,
    by = c("t")
  ) %>%
  group_by(t, n, grid) %>%
  summarise(across(where(is.numeric), ~ mean(.x, na.rm = TRUE)), .groups = "drop") %>%
  mutate(
    bias_n = bar_Psi_Pn - bar_Psi_t_P0,
    bias_os = bar_Psi_t_os - bar_Psi_t_P0,
    # bias_tmle = bar_Psi_t_tmle - bar_Psi_t_P0,
    sqrt_n_bias_n    = n^(1/2) * bias_n,
    sqrt_n_bias_os    = n^(1/2) * bias_os
    # sqrt_n_bias_tmle  = n^(1/2) * bias_tmle
  )

# \bar{\Psi}_n - \bar{\Psi}(P_0)
results_psi_bias %>%
  filter(
    t %in% t_psi_vals
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = bias_n, color = "Plug-in", linetype = "Plug-in"), linewidth = 1) +
  geom_line(aes(y = bias_os, color = "One-step", linetype = "One-step"), linewidth = 1) +
  # geom_line(aes(y = bias_tmle, color = "TMLE", linetype = "TMLE"), linewidth = 1) +
  facet_grid(grid ~ t, scales = "free_y") +
  theme_bw() +
  geom_hline(yintercept = 0) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  labs(
    x = "log_10(n)",
    y = "Mean Bias",
    color = "Estimator",
    linetype = "Estimator",
    title = "Mean Bias of V-specific CATE CDF Primitive by t"
  ) +
  scale_color_manual(values = c("Plug-in" = "plum",
                                "One-step" = "#0072B2",
                                "TMLE" = "#D55E00")) +
  scale_linetype_manual(values = c("Plug-in" = "dotted",
                                   "One-step" = "solid",
                                   "TMLE" = "longdash")) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

# \sqrt{n} (\bar{\Psi}_n - \bar{\Psi}(P_0))
results_psi_bias %>%
  filter(
    t %in% t_psi_vals
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(
    aes(y = sqrt_n_bias_n, color = "Plug-in", linetype = "Plug-in"),
    linewidth = 1
  ) +
  geom_point(
    aes(y = sqrt_n_bias_n, color = "Plug-in", shape = "Plug-in"),
    size = 2
  ) +
  geom_line(
    aes(y = sqrt_n_bias_os, color = "One-step", linetype = "One-step"),
    linewidth = 1
  ) +
  geom_point(
    aes(y = sqrt_n_bias_os, color = "One-step", shape = "One-step"),
    size = 2
  ) +
  # geom_line(
  #   aes(y = sqrt_n_bias_tmle, color = "TMLE", linetype = "TMLE"),
  #   linewidth = 1
  # ) +
  facet_grid(grid ~ t, scales = "free_y") +
  theme_bw() +
  geom_hline(yintercept = 0) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  labs(
    x = "log_10(n)",
    y = latex2exp::TeX("$\\sqrt{n} (\\frac{1}{NSIM} \\sum_{i=1}^{NSIM} \\psi_{t,n,i} - \\psi_{t,0} )$"),
    color = "Estimator",
    linetype = "Estimator",
    shape = "Estimator",
    title = latex2exp::TeX("Root-n Scaled Mean Bias of V-specific CATE CDF Primitive, Faceted by $t$")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE" = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE" = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in"  = 16,
    "One-step" = 17,
    "TMLE"     = 15
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
# compute Monte Carlo SEs by (t, n, grid)
results_psi_se <- results_psi %>%
  group_by(t, n, grid) %>%
  summarise(
    # dplyr::n() counts replicates per cell; the column `n` is the sample size
    se_Pn   = sd(bar_Psi_Pn, na.rm = TRUE) / sqrt(dplyr::n()),
    se_os   = sd(bar_Psi_t_os, na.rm = TRUE) / sqrt(dplyr::n()),
    # se_tmle = sd(bar_Psi_t_tmle, na.rm = TRUE) / sqrt(dplyr::n()),
    .groups = "drop"
  ) %>%
  mutate(
    sqrt_n_se_Pn   = sqrt(n) * se_Pn,
    sqrt_n_se_os   = sqrt(n) * se_os
    # sqrt_n_se_tmle = sqrt(n) * se_tmle
  ) %>%
  ungroup()

# plot root-n scaled Monte Carlo SE
results_psi_se %>%
  filter(
    t %in% t_psi_vals
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(
    aes(y = sqrt_n_se_Pn, color = "Plug-in", linetype = "Plug-in"),
    linewidth = 1
  ) +
  geom_point(
    aes(y = sqrt_n_se_Pn, color = "Plug-in", shape = "Plug-in"),
    size = 2
  ) +
  geom_line(
    aes(y = sqrt_n_se_os, color = "One-step", linetype = "One-step"),
    linewidth = 1
  ) +
  geom_point(
    aes(y = sqrt_n_se_os, color = "One-step", shape = "One-step"),
    size = 2
  ) +
  # geom_line(
  #   aes(y = sqrt_n_se_tmle, color = "TMLE", linetype = "TMLE"),
  #   linewidth = 1
  # ) +
  facet_grid(grid ~ t, scales = "free_y") +
  geom_hline(yintercept = 0, linetype = "dashed") +
  theme_bw() +
  labs(
    x = "Sample size (n)",
    y = latex2exp::TeX("$\\sqrt{n}$-scaled Monte Carlo SE"),
    color = "Estimator",
    linetype = "Estimator",
    shape = "Estimator",
    title = latex2exp::TeX("Root-n Scaled Monte Carlo SE of $\\bar{\\Psi}_t(P_n)$ by t")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE" = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE" = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in"  = 16,
    "One-step" = 17,
    "TMLE"     = 15
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
# Compute MSE for each estimator
results_psi_mse <- results_psi %>%
  left_join(
    parameter_df,
    by = c("t")
  ) %>%
  mutate(
    sq_err_n    = (bar_Psi_Pn    - bar_Psi_t_P0)^2,
    sq_err_os   = (bar_Psi_t_os   - bar_Psi_t_P0)^2
    # sq_err_tmle = (bar_Psi_t_tmle - bar_Psi_t_P0)^2
  ) %>%
  group_by(t, grid, n) %>%
  summarise(
    mse_n    = mean(sq_err_n, na.rm = TRUE),
    mse_os   = mean(sq_err_os, na.rm = TRUE),
    # mse_tmle = mean(sq_err_tmle, na.rm = TRUE),
    .groups = "drop"
  ) %>%
  pivot_longer(
    cols = starts_with("mse_"),
    names_to = "estimator",
    values_to = "mse"
  ) %>%
  mutate(
    sc_mse = n * mse,  # n-scaled MSE, matching the root-n scaling of the bias and SE plots
    estimator = recode(
      estimator,
      mse_n    = "Plug-in",
      mse_os   = "One-step",
      mse_tmle = "TMLE"
    )
  )
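
# A minimal sketch on toy numbers (not simulation output) of how the bias and
# Monte Carlo variance summaries relate to the MSE computed above:
# MSE = bias^2 + variance, where the variance uses a 1/N divisor over replicates.
# `toy_est` and `toy_truth` are hypothetical names introduced for illustration.
toy_truth <- 0
toy_est   <- c(0.2, -0.1, 0.4, 0.3, 0.0)
toy_mse   <- mean((toy_est - toy_truth)^2)
toy_bias  <- mean(toy_est) - toy_truth
toy_var   <- mean((toy_est - mean(toy_est))^2)  # 1/N divisor, unlike var()
stopifnot(isTRUE(all.equal(toy_mse, toy_bias^2 + toy_var)))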

results_psi_mse %>%
  filter(
    t %in% t_psi_subset
  ) %>%
  ggplot(aes(x = n, y = sc_mse, color = estimator, shape = estimator, linetype = estimator)) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  facet_grid(grid ~ t, scales = "free_y") +
  theme_bw() +
  labs(
    x = "Sample Size (n)",
    y = "Scaled Mean Squared Error",
    color = "Estimator",
    linetype = "Estimator",
    shape = "Estimator",
    title = latex2exp::TeX("Mean Squared Error of $\\bar{\\Psi}_{t}$ Estimators")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE"    = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE"    = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in"  = 16,
    "One-step" = 17,
    "TMLE"     = 15
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
# Compute empirical coverage for each estimator
results_psi_coverage <- results_psi %>%
  left_join(
    parameter_df,
    by = c("t")
  ) %>%
  select(
    t, n, grid, bar_Psi_t_P0,
    starts_with("bar_Psi_"),
    SD_bar_D_tn, SD_Psi_Pn
  ) %>%
  # Long format for the estimators' point estimates
  pivot_longer(
    cols         = c(bar_Psi_Pn, bar_Psi_t_os), # bar_Psi_t_tmle),
    names_to     = "estimator",
    names_prefix = "bar_Psi_",
    values_to    = "Psi_hat"
  ) %>%
  mutate(
    # use plug-in SD for plug-in confidence intervals
    SD_bar_D_tn = if_else(estimator == "Pn", SD_Psi_Pn, SD_bar_D_tn),
    
    # Wald-type CIs
    ci_lower    = Psi_hat   - 1.96 * SD_bar_D_tn / sqrt(n),
    ci_upper    = Psi_hat   + 1.96 * SD_bar_D_tn / sqrt(n),
    
    # Coverage indicator
    covered     = (ci_lower <= bar_Psi_t_P0 & bar_Psi_t_P0 <= ci_upper),
    
    # Nice estimator labels for plotting
    estimator   = dplyr::recode(
      estimator,
      "Pn"    = "Plug-in",
      "t_os"   = "One-step",
      "t_tmle" = "TMLE"
    )
  ) %>%
  group_by(n, t, grid, estimator) %>%
  summarise(
    coverage = mean(covered, na.rm = TRUE),
    .groups  = "drop"
  )
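
# A hedged sketch for reading the coverage plots: each coverage value is a
# proportion over Monte Carlo replicates, so, assuming independent replicates,
# its binomial Monte Carlo standard error is sqrt(p * (1 - p) / n_sim).
# `coverage_mc_se` and `n_sim` are hypothetical names, not part of the pipeline.
coverage_mc_se <- function(coverage, n_sim) {
  sqrt(coverage * (1 - coverage) / n_sim)
}
# With a hypothetical 1000 replicates, nominal 95% coverage has MC SE of about 0.007
coverage_mc_se(0.95, n_sim = 1000)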


results_psi_coverage %>%
  filter(
    t %in% t_psi_subset
  ) %>%
  ggplot(aes(
    x = n, y = coverage,
    color = estimator,
    shape = estimator,
    linetype = estimator
  )) +
  geom_line(linewidth = 0.9) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0.95, linetype = "dashed", color = "black") +
  facet_grid(grid ~ t, scales = "free_y") +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$\\log_{10}(n)$"),
    y = "Empirical 95% Coverage",
    color = "Estimator",
    linetype = "Estimator",
    shape = "Estimator",
    title = latex2exp::TeX("Empirical Coverage of 95\\% Wald Confidence Intervals for $\\bar{\\Psi}_{t,0}$")
  ) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE" = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE" = "longdash"
  )) +
  coord_cartesian(ylim = c(0, 1.0)) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

results_psi_coverage %>%
  ggplot(aes(
    x = t, y = coverage,
    color    = estimator,
    linetype = estimator,
    shape    = estimator
  )) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0.95) +
  facet_grid(grid ~ n, scales = "free_y") +
  theme_bw() +
  labs(
    x        = "t",
    y        = "Empirical 95% Coverage",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Empirical Coverage of 95\\% Confidence Intervals for $\\bar{\\Psi}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE"    = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE"    = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
# Compute √n-standardized quantities (KEEP grid)
results_standardized <- results_psi %>%
  left_join(parameter_df, by = "t") %>%
  mutate(
    z_Pn   = sqrt(n) * (bar_Psi_Pn   - bar_Psi_t_P0) / SD_Psi_Pn,
    z_os   = sqrt(n) * (bar_Psi_t_os - bar_Psi_t_P0) / SD_bar_D_t0
    # z_tmle = sqrt(n) * (bar_Psi_t_tmle - bar_Psi_t_P0) / SD_bar_D_t0
  ) %>%
  select(t, n, grid, starts_with("z_")) %>%   # <-- include grid
  pivot_longer(
    cols = starts_with("z_"),
    names_to = "estimator",
    values_to = "z_value"
  ) %>%
  mutate(
    estimator = recode(estimator,
                       "z_Pn"   = "Plug-in",
                       "z_os"   = "One-step",
                       "z_tmle" = "TMLE"),
    n = factor(n)  # <-- force discrete fill everywhere
  )

# Standard normal reference curve
normal_ref <- tibble(
  x = seq(-4, 4, length.out = 400),
  density = dnorm(x)
)

# Pre-filter once
df_rootn <- results_standardized %>%
  filter(
    n %in% 10000,
    t %in% c(-0.25, 0, 0.25)
  )

# Plot function for one grid value
dist_plot_for_grid_rootn <- function(dat, grid_val, normal_ref) {
  dat %>%
    filter(grid == grid_val) %>%
    ggplot(aes(x = z_value, fill = n)) +
    geom_histogram(
      aes(y = after_stat(density)),
      position = "identity",
      alpha = 0.4,
      color = "black",
      bins = 40
    ) +
    geom_line(
      data = normal_ref,
      aes(x = x, y = density),
      color = "black", linewidth = 1, linetype = "solid",
      inherit.aes = FALSE
    ) +
    geom_vline(xintercept = 0, linetype = "dashed") +
    facet_grid(estimator ~ t, scales = "free") +
    theme_bw() +
    labs(
      x = latex2exp::TeX("$\\sqrt{n}(\\bar{\\Psi}_{t,n} - \\bar{\\Psi}_{t,0}) / SD(\\bar{D}_t)$"),
      y = "Density",
      fill = "Sample size (n)",
      # TeX() returns an expression, so it cannot be pasted with a string;
      # the grid value goes in the subtitle instead
      title = latex2exp::TeX("Empirical Distribution of Root-n Standardized Estimators of $\\bar{\\Psi}_{t,0}$ with $N(0,1)$ Overlay"),
      subtitle = paste("grid:", grid_val)
    ) +
    scale_fill_brewer(palette = "Set2") +
    theme(
      strip.text = element_text(size = 10),
      aspect.ratio = 1,
      legend.position = "bottom"
    )
}

# Make one plot per grid
grid_vals <- sort(unique(df_rootn$grid))
names(grid_vals) <- as.character(grid_vals)

rootn_plots <- map(
  grid_vals,
  ~ dist_plot_for_grid_rootn(df_rootn, .x, normal_ref)
)

# Print them all
walk(rootn_plots, print)


ggplot(parameter_df, aes(x=t, y = SD_bar_D_t0)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("SD($\\bar{D}_{t,P}$)"),
    title = "Standard Deviation of the Gradient"
  ) +
  theme(aspect.ratio = 1)

# summarize remainders by n and t, then scale by sqrt(n)
results_remainder_summary <- results_psi %>%
  group_by(t, grid, n) %>%
  summarise(
    # mean
    mean_bar_R1 = mean(bar_R1, na.rm = TRUE),
    mean_R_decomp_1 = mean(R_decomp_1, na.rm = TRUE),
    mean_R_decomp_2 = mean(R_decomp_2, na.rm = TRUE),
    mean_R_decomp_3 = mean(R_decomp_3, na.rm = TRUE),
    mean_R_decomp_4 = mean(R_decomp_4, na.rm = TRUE),
    mean_bar_R_total = mean(bar_R_total, na.rm = TRUE),
    
    # standard deviation
    sd_bar_R1 = sd(bar_R1, na.rm = TRUE),
    sd_R_decomp_1 = sd(R_decomp_1, na.rm = TRUE),
    sd_R_decomp_2 = sd(R_decomp_2, na.rm = TRUE),
    sd_R_decomp_3 = sd(R_decomp_3, na.rm = TRUE),
    sd_R_decomp_4 = sd(R_decomp_4, na.rm = TRUE),
    sd_bar_R_total = sd(bar_R_total, na.rm = TRUE),
    .groups = "drop"
  ) %>%
  # scale by √n
  mutate(
    across(starts_with("mean_"), ~ sqrt(n) * .x),
    across(starts_with("sd_"), ~ sqrt(n) * .x)
  )

# plot √n-scaled empirical mean of remainders
results_remainder_summary %>%
  pivot_longer(
    cols = starts_with("mean_"),
    names_to = "remainder",
    values_to = "scaled_mean_value"
  ) %>%
  # tidy up labels
  mutate(
    remainder = recode(remainder,
                       "mean_bar_R1" = "bar_R1 (nuisance error)",
                       "mean_R_decomp_1" = "R1: underestimation",
                       "mean_R_decomp_2" = "R2: missed mass below t",
                       "mean_R_decomp_3" = "R3: overestimation",
                       "mean_R_decomp_4" = "R4: boundary term",
                       "mean_bar_R_total" = "Total remainder"
    )
  ) %>%
  # filter to subset of t
  filter(
    t %in% t_psi_subset
  ) %>%
  ggplot(aes(x = n, y = scaled_mean_value, color = remainder, linetype = remainder)) +
  geom_line(linewidth = 0.9) +
  facet_grid(grid ~ t, scales = "free_y") +
  geom_hline(yintercept = 0, linetype = "dashed") +
  theme_bw() +
  labs(
    x = "log_10(n)",
    y = latex2exp::TeX("$\\sqrt{n} \\times$ Empirical mean remainder"),
    color = "Remainder component",
    linetype = "Remainder component",
    title = latex2exp::TeX("Root-n Scaled Empirical Mean of Remainder Terms, Faceted by $t$")
  ) +
  scale_color_manual(values = c(
    "bar_R1 (nuisance error)" = "plum",
    "R1: underestimation" = "#0072B2",
    "R2: missed mass below t" = "#D55E00",
    "R3: overestimation" = "#009E73",
    "R4: boundary term" = "black",
    "Total remainder" = "purple"
  )) +
  scale_linetype_manual(values = c(
    "bar_R1 (nuisance error)" = "dotted",
    "R1: underestimation" = "solid",
    "R2: missed mass below t" = "solid",
    "R3: overestimation" = "solid",
    "R4: boundary term" = "dotdash",
    "Total remainder" = "longdash"
  )) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
# plot √n-scaled empirical SD of remainders, by n and t
results_remainder_summary %>%
  pivot_longer(
    cols = starts_with("sd_"),
    names_to = "remainder",
    values_to = "scaled_sd_value"
  ) %>%
  # tidy up labels
  mutate(
    remainder = recode(remainder,
                       "sd_bar_R1" = "bar_R1 (nuisance error)",
                       "sd_R_decomp_1" = "R1: underestimation",
                       "sd_R_decomp_2" = "R2: missed mass below t",
                       "sd_R_decomp_3" = "R3: overestimation",
                       "sd_R_decomp_4" = "R4: boundary term",
                       "sd_bar_R_total" = "Total remainder"
    )
  ) %>%
  # filter to subset of t
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n, y = scaled_sd_value, color = remainder, linetype = remainder)) +
  geom_line(linewidth = 0.9) +
  facet_grid(grid ~ t) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  theme_bw() +
  labs(
    x = "log_10(n)",
    y = latex2exp::TeX("$\\sqrt{n} \\times$ Empirical SD of remainder"),
    color = "Remainder component",
    linetype = "Remainder component",
    title = latex2exp::TeX("Root-n Scaled Empirical SD of Remainder"),
    subtitle =  "Faceted by t"
  ) +
  scale_color_manual(values = c(
    "bar_R1 (nuisance error)" = "plum",
    "R1: underestimation" = "#0072B2",
    "R2: missed mass below t" = "#D55E00",
    "R3: overestimation" = "#009E73",
    "R4: boundary term" = "black",
    "Total remainder" = "purple"
  )) +
  scale_linetype_manual(values = c(
    "bar_R1 (nuisance error)" = "dotted",
    "R1: underestimation" = "solid",
    "R2: missed mass below t" = "solid",
    "R3: overestimation" = "solid",
    "R4: boundary term" = "dotdash",
    "Total remainder" = "longdash"
  )) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )

# tibble(
#   "W1" = seq(0, 1, by=0.001)
# ) %>%
#   mutate(
#     bar_tau = beta1 + beta5*W1,
#     # Unif(beta1, beta1+beta5)
#     f_0bar_tau0_t = 1 / (max(bar_tau) - min(bar_tau))
#   )

ggplot(parameter_df, aes(x=t, y = f_0bar_tau0_t)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$f_{\\bar{\\tau}_0}(t)$"),
    title = "Density of V-Specific CATE"
  ) +
  theme(aspect.ratio = 1)

# bias_f_bartau
results_theta_bias %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = bias_f_bartau, color = "f_bartau", linetype = "f_bartau")) +
  facet_grid(grid ~ t) +
  theme_bw() +
  geom_hline(yintercept = 0) +
  labs(
    x = "log_10(n)",
    y = "Bias",
    color = "Estimator",
    linetype = "Estimator",
    title = latex2exp::TeX("Bias of $f_0$")
  ) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  scale_color_manual(values = c("f_bartau" = "#D55E00")) +
  scale_linetype_manual(values = c("f_bartau" = "longdash")) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
data.frame(
  barc0V,
  V
) %>%
  ggplot(aes(x=V, y = barc0V)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$V=W_1$"),
    y = latex2exp::TeX("$\\bar{c}_0(V)$")
  ) +
  theme(aspect.ratio = 1)

# Plot directly from parameter_df rather than using `$` inside aes()
parameter_df %>%
  ggplot(aes(x = t, y = E_barc0_bar_tau0_t)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$E[\\bar{c}_0(V) | \\bar{\\tau}_0(V) = t]$")
  ) +
  theme(aspect.ratio = 1)

# E_barc_bartau bias
results_theta_bias %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = bias_E_hatc, color = "bias_E_hatc_bartau", linetype = "bias_E_hatc_bartau")) +
  facet_grid(grid ~ t) +
  theme_bw() +
  geom_hline(yintercept = 0) +
  labs(
    x = "n",
    y = "Bias",
    color = "Estimator",
    linetype = "Estimator",
    title = latex2exp::TeX("Bias of $E[\\bar{c}_0(V) | \\bar{\\tau}_0(V) = t]$")
  ) +
  # scale_x_log10(
  #   breaks = scales::trans_breaks("log10", function(x) 10^x),
  #   labels = scales::trans_format("log10", scales::math_format(10^.x))
  # ) +
  scale_color_manual(values = c("bias_E_hatc_bartau" = "#D55E00")) +
  scale_linetype_manual(values = c("bias_E_hatc_bartau" = "longdash")) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
parameter_df %>%
  ggplot(aes(x=t, y = bar_rho0)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$\\bar{\\rho}_0(t)$")
  ) +
  theme(aspect.ratio = 1)

# bar_rho bias
results_theta_bias %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = bias_bar_rho, color = "bar_rho_n", linetype = "bar_rho_n")) +
  facet_grid(grid ~ t) +
  theme_bw() +
  geom_hline(yintercept = 0) +
  labs(
    x = "log_10(n)",
    y = "Bias",
    color = "Estimator",
    linetype = "Estimator",
    title = latex2exp::TeX("Bias of $\\bar{\\rho}_n$")
  ) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  scale_color_manual(values = c("bar_rho_n" = "#D55E00")) +
  scale_linetype_manual(values = c("bar_rho_n" = "longdash")) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )